#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <algorithm>    // std::sort
#include <cmath>        // sqrt
#include <cstdlib>      // std::exit, EXIT_FAILURE
#include <iostream>     // std::cout, std::cerr
#include <limits>       // std::numeric_limits<>
#include <tuple>        // std::tie, std::ignore
#include <utility>      // std::pair

// jSCIENCE
#include "jScience/stl/vector.hpp" // magnitude()
#include "jutility.hpp" // get_n_rand_unique_elements()
#include "jstats.hpp"   // mean(), stddev()

// stats++
#include "statsxx/dataset.hpp"     // DataSet, partition_dataset()


//
// DESC: Trains the neural network on TS_tr (TS_val is passed to the trainer for early stopping), then returns
//       the generalization error measured on TS_gen.
//
// INPUT:
//          method == 0 :: Backpropagation  -- uses nepoch_max, early_stopping, lr, momentum, weight_penalty,
//                                             TS_tr, TS_val, npts_per_batch
//          method == 1 :: QRprop           -- uses nepoch_max, early_stopping, lr, lr_min, lr_max,
//                                             qrprop_u, qrprop_d, TS_tr, TS_val
//          method == 2 :: CG               -- uses nepoch_max, early_stopping, TS_tr, TS_val
//          method == 3 :: SCG              -- uses nepoch_min, nepoch_max, early_stopping, scg_lambda, scg_sigma,
//                                             TS_tr, TS_val, scg_convg_iterfrac, scg_rk_tol, silent
//
//          Parameters not listed for the selected method are accepted but ignored.
//
// OUTPUT:
//          The first element (Emean) of the tuple returned by parse_TS(TS_gen).
//
// ERRORS:
//          An unrecognized method prints a message to std::cerr and terminates the process with EXIT_FAILURE.
//
// TODO: NOTE: It might be worthwhile to set up structs/classes with the training parameters (parent --- abstract, inherited --- specific) that can be passed to this subroutine.
//
// TODO: Should update the lr parameters to be per point --- this would make simulations consistent with different size datasets.
//
inline double NEURAL_NET::train(
                                const int      method,
                                //------
                                const int      nepoch_min,
                                const int      nepoch_max,
                                // -----
                                const bool     early_stopping,
                                // -----
                                const double   lr,
                                const double   lr_min,
                                const double   lr_max,
                                // -----
                                const double   momentum,
                                // -----
                                const double   weight_penalty,
                                // -----
                                const double   qrprop_u,
                                const double   qrprop_d,
                                // -----
                                const double   scg_lambda,
                                const double   scg_sigma,
                                const double   scg_convg_iterfrac,
                                const double   scg_rk_tol,
                                // -----
                                const DataSet &TS_tr,
                                const DataSet &TS_val,
                                const DataSet &TS_gen,
                                // -----
                                const int      npts_per_batch,
                                // -----
                                const bool     silent
                                )
{
    switch(method)
    {
        case 0:
            this->backpropagation(
                                  nepoch_max,
                                  // -----
                                  early_stopping,
                                  // -----
                                  lr,
                                  // -----
                                  momentum,
                                  // -----
                                  weight_penalty,
                                  // -----
                                  TS_tr,
                                  TS_val,
                                  // -----
                                  npts_per_batch
                                  );
            break;
        case 1:
            this->QRprop(
                         nepoch_max,
                         // -----
                         early_stopping,
                         // -----
                         lr,
                         lr_min,
                         lr_max,
                         // -----
                         qrprop_u,
                         qrprop_d,
                         // -----
                         TS_tr,
                         TS_val
                         );
            break;
        case 2:
            this->CG(
                     nepoch_max,
                     // -----
                     early_stopping,
                     // -----
                     TS_tr,
                     TS_val
                     );
            break;
        case 3:
            this->SCG(
                      nepoch_min,
                      nepoch_max,
                      // -----
                      early_stopping,
                      // -----
                      scg_lambda,
                      scg_sigma,
                      // -----
                      TS_tr,
                      TS_val,
                      // -----
                      scg_convg_iterfrac,
                      scg_rk_tol,
                      // -----
                      silent
                      );
            break;
        default:
            // Diagnostics belong on stderr, and the process must report failure (the original exited with 0).
            std::cerr << "Error in NEURAL_NET::train(): Training method (imethod) not recognized!" << std::endl;
            std::exit(EXIT_FAILURE);
    }

    // parse_TS() returns (Emean, Estddev, NN_out); only the mean is needed here, so the rest is discarded.
    double Emean;
    std::tie(Emean, std::ignore, std::ignore) = parse_TS(TS_gen);

    return Emean;
}
